
import numpy as np
import torch
import torch.nn.functional as F


class RetMetric(object):
    """Recall@K retrieval metric based on a dot-product similarity matrix.

    Two input layouts are accepted:

    * ``feats`` is a list of exactly two items, ``[gallery_feats,
      query_feats]``, with ``labels = [gallery_labels, query_labels]``:
      queries and gallery are separate sets.
    * anything else: ``feats`` / ``labels`` describe a single set that is
      used both as queries and as gallery (``is_equal_query`` is True), and
      every query is excluded from its own candidate list.

    The similarity matrix ``sim_mat`` has one row per query and one column
    per gallery item and is computed as ``query_feats @ gallery_feats.T``.
    """

    def __init__(self, feats, labels):
        """Store features/labels and precompute the similarity matrix.

        Args:
            feats: either ``[gallery_feats, query_feats]`` (a list of length
                two) or a single 2-D feature array shared by queries and
                gallery.
            labels: labels laid out the same way as ``feats``.
        """
        if isinstance(feats, list) and len(feats) == 2:
            # Separate gallery and query sets.
            self.is_equal_query = False
            self.gallery_feats, self.query_feats = feats
            self.gallery_labels, self.query_labels = labels
        else:
            # A single set plays both roles.
            self.is_equal_query = True
            self.gallery_feats = self.query_feats = feats
            self.gallery_labels = self.query_labels = labels

        self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))

    def recall_k(self, k=1):
        """Return Recall@k over all queries.

        A query is a hit when fewer than ``k`` gallery items with a different
        label are strictly more similar to it than its best positive (a
        gallery item with the same label, the query itself excluded when
        queries and gallery are the same set).

        Queries that have no positive in the gallery can never be retrieved
        correctly and are counted as misses. With no queries at all the
        result is ``0.0``.

        Args:
            k: rank cut-off.

        Returns:
            Fraction of queries that are hits, as a float in ``[0, 1]``.
        """
        sim_mat = np.asarray(self.sim_mat)
        # Coerce to arrays so that ``labels == scalar`` is element-wise even
        # when plain Python lists were supplied.
        gallery_labels = np.asarray(self.gallery_labels)
        query_labels = np.asarray(self.query_labels)

        num_queries = len(sim_mat)
        if num_queries == 0:
            return 0.0

        match_counter = 0
        for i in range(num_queries):
            row = sim_mat[i]
            pos_mask = gallery_labels == query_labels[i]
            if self.is_equal_query:
                # Drop the query itself explicitly instead of assuming that
                # its self-similarity is the largest positive similarity.
                pos_mask = pos_mask.copy()
                pos_mask[i] = False

            pos_sim = row[pos_mask]
            if pos_sim.size == 0:
                # No positive available for this query: count it as a miss.
                continue

            neg_sim = row[gallery_labels != query_labels[i]]
            if np.sum(neg_sim > np.max(pos_sim)) < k:
                match_counter += 1

        return float(match_counter) / num_queries

class RetMetricMap(object):
    """Per-query ranking metrics (R@1, R-precision, MAP@R) and an
    optimal-transport based similarity between region features."""

    def __init__(self):
        self.ok = 1

    def _metrics_from_ranking(self, tops, query_label, gallery_label):
        """Compute ``(r1, rp, mapr)`` from a gallery ranking for one query.

        Args:
            tops: gallery indices ordered from most to least similar.
            query_label: label of the query.
            gallery_label: 1-D tensor with the label of every gallery item.

        Returns:
            Tuple of Python floats ``(r1, rp, mapr)``:
            ``r1`` is 1.0 when the top-ranked item shares the query label;
            ``rp`` is the fraction of the first ``num_pos`` ranked items that
            share the query label, where ``num_pos`` is the number of gallery
            items with the query label; ``mapr`` is the mean over those first
            ``num_pos`` ranks of precision-at-rank, counted only at ranks
            holding a correct item. When the ranking is empty, or the gallery
            holds no item with the query label, the undefined metrics are
            reported as 0.0 rather than NaN.
        """
        if len(tops) == 0:
            return 0.0, 0.0, 0.0

        r1 = 1.0 if query_label == gallery_label[tops[0]] else 0.0

        num_pos = int(torch.sum(gallery_label == query_label).item())
        if num_pos == 0:
            # R-precision and MAP@R are undefined without positives.
            return r1, 0.0, 0.0

        equality = (gallery_label[tops[0:num_pos]] == query_label).float()
        rp = (torch.sum(equality) / float(num_pos)).item()

        cumulative_correct = torch.cumsum(equality, dim=0)
        # Build the rank indices on the same device as the labels so that the
        # division below also works for non-CPU tensors.
        k_idx = torch.arange(1, num_pos + 1, device=equality.device)
        # Multiplying by ``equality`` zeroes the precision at incorrect ranks.
        precision_at_ks = (cumulative_correct * equality) / k_idx
        mapr = torch.mean(precision_at_ks).item()

        return r1, rp, mapr

    def get_metrics(self, sim, query_label, gallery_label):
        """Return ``(r1, rp, mapr)`` for one query given its similarities.

        Args:
            sim: 1-D tensor of similarities between the query and every
                gallery item (larger means more similar).
            query_label: label of the query.
            gallery_label: 1-D tensor of gallery labels, aligned with ``sim``.
        """
        tops = torch.argsort(sim, descending=True)
        return self._metrics_from_ranking(tops, query_label, gallery_label)

    def get_metrics_rank(self, tops, query_label, gallery_label):
        """Return ``(r1, rp, mapr)`` for one query given a precomputed ranking.

        Args:
            tops: gallery indices ordered from most to least similar.
            query_label: label of the query.
            gallery_label: 1-D tensor of gallery labels.
        """
        return self._metrics_from_ranking(tops, query_label, gallery_label)

    def Sinkhorn(self, K, u, v, max_iter=100, thresh=1e-1):
        """Run Sinkhorn scaling iterations and return the transport plan.

        Args:
            K: batched kernel matrix; ``calc_similarity`` passes a tensor of
                shape ``(N, R, R)``.
            u: row marginals; ``calc_similarity`` passes shape ``(N, R)``.
            v: column marginals; ``calc_similarity`` passes shape ``(N, R)``.
            max_iter: maximum number of scaling iterations.
            thresh: iteration stops early once the mean absolute change of
                the row scaling vector drops below this value.

        Returns:
            Tensor with the shape of ``K`` holding ``r[:, :, None] *
            c[:, None, :] * K`` for the final scaling vectors ``r`` and ``c``.
        """
        r = torch.ones_like(u)
        c = torch.ones_like(v)
        K_t = K.permute(0, 2, 1).contiguous()  # loop-invariant transpose
        for _ in range(max_iter):
            r0 = r
            # NOTE(review): no epsilon guards these divisions; an all-zero
            # marginal would produce inf/NaN here.
            r = u / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1)
            c = v / torch.matmul(K_t, r.unsqueeze(-1)).squeeze(-1)
            err = (r - r0).abs().mean()
            if err.item() < thresh:
                break
        T = torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K
        return T

    def calc_similarity(self, anchor, anchor_center, fb, fb_center, stage,
                        use_uniform=False, temperature=0.05):
        """Similarity between one anchor and ``N`` gallery items.

        Args:
            anchor: anchor region features, shape ``(C, R)``.
            anchor_center: anchor global feature, shape ``(C,)``.
            fb: gallery region features, shape ``(N, C, R)``.
            fb_center: gallery global features, shape ``(N, C)``.
            stage: ``0`` returns the dot product of the global features; any
                other value returns the transport-weighted sum of the
                region-to-region similarities.
            use_uniform: when True both marginals are uniform (``1 / R``);
                otherwise they are ReLU-ed cross-correlations between one
                side's global feature and the other side's region features,
                normalised to sum to (approximately) one.
            temperature: scale of the kernel ``exp(-(1 - sim) / temperature)``.

        Returns:
            Tensor of shape ``(N,)`` with one similarity per gallery item.
        """
        if stage == 0:
            return torch.einsum('c,nc->n', anchor_center, fb_center)

        N, _, R = fb.size()

        # sim[n, s, m]: gallery region s of item n against anchor region m.
        sim = torch.einsum('cm,ncs->nsm', anchor, fb).contiguous().view(N, R, R)
        dis = 1.0 - sim
        K = torch.exp(-dis / temperature)

        if use_uniform:
            u = torch.full((N, R), 1. / R, dtype=sim.dtype, device=sim.device)
            v = torch.full((N, R), 1. / R, dtype=sim.dtype, device=sim.device)
        else:
            att = F.relu(torch.einsum("c,ncr->nr", anchor_center, fb)).view(N, R)
            u = att / (att.sum(dim=1, keepdim=True) + 1e-5)

            att = F.relu(torch.einsum("cr,nc->nr", anchor, fb_center)).view(N, R)
            v = att / (att.sum(dim=1, keepdim=True) + 1e-5)

        T = self.Sinkhorn(K, u, v)
        return torch.sum(T * sim, dim=(1, 2))


